{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "import torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset\n",
    "from torchvision import transforms, datasets, models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training pipeline: random crop + random horizontal flip, producing a 3x224x224\n",
    "# tensor whose channels are normalized with mean 0.5 and std 0.5.\n",
    "data_transform = transforms.Compose([\n",
    "    transforms.RandomResizedCrop(224),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "])\n",
    "# Evaluation pipeline: deterministic resize + center crop to the same 3x224x224\n",
    "# shape, so validation accuracy is not perturbed by random augmentation.\n",
    "eval_transform = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "])\n",
    "\n",
    "# Read the training set from data/train and the validation set from data/val.\n",
    "train_img = torchvision.datasets.ImageFolder('data/train', transform=data_transform)\n",
    "test_img = torchvision.datasets.ImageFolder('data/val', transform=eval_transform)\n",
    "\n",
    "# Only the training batches are shuffled; evaluation keeps a fixed order so the\n",
    "# results are reproducible from run to run.\n",
    "train_data = torch.utils.data.DataLoader(train_img, batch_size=4, shuffle=True)\n",
    "test_data = torch.utils.data.DataLoader(test_img, batch_size=4, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'nn' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-2-e0b0117c86ac>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mclass\u001b[0m \u001b[0mNet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      2\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      3\u001b[0m         \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mNet\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv1\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mConv2d\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m6\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m5\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmaxpool\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mMaxPool2d\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'nn' is not defined"
     ]
    }
   ],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Two convolution blocks followed by a linear classifier.\n",
    "\n",
    "    The classifier is sized for 3x224x224 inputs, which is what the notebook's\n",
    "    transforms produce. The output is one raw score (logit) per class, suitable\n",
    "    for nn.CrossEntropyLoss.\n",
    "\n",
    "    Args:\n",
    "        num_classes: number of output classes (defaults to the 3 classes used\n",
    "            in this notebook).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes=3):\n",
    "        super(Net, self).__init__()\n",
    "        # 3x224x224, conv(k=5, p=1) gives 16x222x222, pool(2) gives 16x111x111\n",
    "        self.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(3, 16, kernel_size=5, padding=1),\n",
    "            nn.BatchNorm2d(16),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2)\n",
    "        )\n",
    "        # 16x111x111, conv(k=2, p=1) gives 32x112x112, pool(2) gives 32x56x56\n",
    "        self.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(16, 32, kernel_size=2, padding=1),\n",
    "            nn.BatchNorm2d(32),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2)\n",
    "        )\n",
    "        # Fully connected classifier over the flattened 32x56x56 feature map.\n",
    "        self.fc = nn.Linear(32 * 56 * 56, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Forward pass: conv block 1, conv block 2, flatten, classifier.\"\"\"\n",
    "        out = self.conv1(x)\n",
    "        out = self.conv2(out)\n",
    "        out = out.view(out.size(0), -1)  # (batch, 32*56*56)\n",
    "        return self.fc(out)\n",
    "\n",
    "\n",
    "net = Net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the first CUDA GPU when one is available, otherwise the CPU,\n",
    "# and move the model onto the selected device.\n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print('device', device)\n",
    "net.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss function and optimizer.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "# Number of mini-batches between progress reports. The same value is used as\n",
    "# the divisor so the printed number is the true mean loss per mini-batch.\n",
    "LOG_EVERY = 200\n",
    "\n",
    "# Train the network.\n",
    "net.train()  # the model has BatchNorm layers, so make the training mode explicit\n",
    "for epoch in range(2):\n",
    "    running_loss = 0.0\n",
    "    for i, data in enumerate(train_data, 0):\n",
    "        inputs, labels = data[0].to(device), data[1].to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Accumulate the loss and report its mean every LOG_EVERY mini-batches.\n",
    "        running_loss += loss.item()\n",
    "        if i % LOG_EVERY == LOG_EVERY - 1:\n",
    "            print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / LOG_EVERY))\n",
    "            running_loss = 0.0\n",
    "\n",
    "print('Finished Training')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the model: only the parameters (state_dict) are written to disk.\n",
    "PATH_checkpoint = './data/model.pth'\n",
    "torch.save(obj=net.state_dict(), f=PATH_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imshow(img):\n",
    "    \"\"\"Display a normalized C x H x W image tensor with matplotlib.\"\"\"\n",
    "    img = img / 2 + 0.5  # undo Normalize(mean=0.5, std=0.5)\n",
    "    plt.imshow(np.transpose(img.cpu().numpy(), (1, 2, 0)))\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Class names indexed by label id; labels are assumed to lie in range(len(classes)).\n",
    "classes = ('cuke','lettuce','lotus_root'  )\n",
    "\n",
    "# Load the saved weights into a fresh model and switch it to inference mode so\n",
    "# BatchNorm uses its running statistics instead of per-batch statistics.\n",
    "net_predict = Net()\n",
    "net_predict.load_state_dict(torch.load(PATH_checkpoint, map_location=device))\n",
    "net_predict.to(device)\n",
    "net_predict.eval()\n",
    "\n",
    "# Per-class accuracy on the validation set.\n",
    "class_correct = [0] * len(classes)\n",
    "class_total = [0] * len(classes)\n",
    "with torch.no_grad():\n",
    "    for images, labels in test_data:\n",
    "        outputs = net_predict(images.to(device))\n",
    "        predicted = outputs.argmax(dim=1).cpu()\n",
    "        hits = predicted == labels\n",
    "        # Count every sample actually present in the batch; the last batch may\n",
    "        # hold fewer images than batch_size.\n",
    "        for label, hit in zip(labels.tolist(), hits.tolist()):\n",
    "            class_correct[label] += int(hit)\n",
    "            class_total[label] += 1\n",
    "\n",
    "for i, name in enumerate(classes):\n",
    "    if class_total[i] == 0:\n",
    "        # Avoid a division by zero for a class with no validation images.\n",
    "        print('Accuracy of %5s : N/A (no samples)' % name)\n",
    "        continue\n",
    "    print('Accuracy of %5s : %2d %%' % (\n",
    "        name, 100 * class_correct[i] / class_total[i]))\n",
    "\n",
    "# Show the last evaluated batch together with its ground-truth labels.\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(len(labels))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
